#include <lc/bt.h>
#include <algorithm>
#include <iostream>
using namespace std;

class Solution
{
   private:
    int maxSum = INT_MIN;

   public:
    int maxGain(TreeNode* node)
    {
        if (node == nullptr) {
            return 0;
        }

        // 递归计算左右子节点的最大贡献值
        // 只有在最大贡献值大于 0 时，才会选取对应子节点
        int leftGain = max(maxGain(node->left), 0);
        int rightGain = max(maxGain(node->right), 0);

        // 节点的最大路径和取决于该节点的值与该节点的左右子节点的最大贡献值
        int priceNewpath = node->val + leftGain + rightGain;

        // 更新答案
        maxSum = max(maxSum, priceNewpath);

        // 返回节点的最大贡献值
        return node->val + max(leftGain, rightGain);
    }

    int maxPathSum(TreeNode* root)
    {
        maxGain(root);
        return maxSum;
    }
};

int main(int argc, char* argv[])
{
    Solution s;
    //  -10
    //  / \
    // 9  20
    //   /  \
    //  15   7
    TreeNode* t = constructT("[-10,9,20,null,null,15,7]");
    assert(s.maxPathSum(t) == 42);
    return 0;
}
